{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3.1 表格数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  7.        ,   0.27000001,   0.36000001, ...,   0.44999999,\n",
       "          8.80000019,   6.        ],\n",
       "       [  6.30000019,   0.30000001,   0.34      , ...,   0.49000001,\n",
       "          9.5       ,   6.        ],\n",
       "       [  8.10000038,   0.28      ,   0.40000001, ...,   0.44      ,\n",
       "         10.10000038,   6.        ],\n",
       "       ..., \n",
       "       [  6.5       ,   0.23999999,   0.19      , ...,   0.46000001,\n",
       "          9.39999962,   6.        ],\n",
       "       [  5.5       ,   0.28999999,   0.30000001, ...,   0.38      ,\n",
       "         12.80000019,   7.        ],\n",
       "       [  6.        ,   0.20999999,   0.38      , ...,   0.31999999,\n",
       "         11.80000019,   6.        ]], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import csv\n",
    "import numpy as np\n",
    "wine_path = \"../../data/chapter3/winequality-white.csv\"\n",
    "wineq_numpy = np.loadtxt(wine_path, dtype=np.float32, delimiter=\";\",\n",
    "     skiprows=1)\n",
    "wineq_numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((4898, 12),\n",
       " ['fixed acidity',\n",
       "  'volatile acidity',\n",
       "  'citric acid',\n",
       "  'residual sugar',\n",
       "  'chlorides',\n",
       "  'free sulfur dioxide',\n",
       "  'total sulfur dioxide',\n",
       "  'density',\n",
       "  'pH',\n",
       "  'sulphates',\n",
       "  'alcohol',\n",
       "  'quality'])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "col_list = next(csv.reader(open(wine_path), delimiter=';'))\n",
    "wineq_numpy.shape, col_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898, 12]), 'torch.FloatTensor')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "wineq = torch.from_numpy(wineq_numpy)\n",
    "wineq.shape, wineq.type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 7.0000,  0.2700,  0.3600,  ...,  3.0000,  0.4500,  8.8000],\n",
       "         [ 6.3000,  0.3000,  0.3400,  ...,  3.3000,  0.4900,  9.5000],\n",
       "         [ 8.1000,  0.2800,  0.4000,  ...,  3.2600,  0.4400, 10.1000],\n",
       "         ...,\n",
       "         [ 6.5000,  0.2400,  0.1900,  ...,  2.9900,  0.4600,  9.4000],\n",
       "         [ 5.5000,  0.2900,  0.3000,  ...,  3.3400,  0.3800, 12.8000],\n",
       "         [ 6.0000,  0.2100,  0.3800,  ...,  3.2600,  0.3200, 11.8000]]),\n",
       " torch.Size([4898, 11]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = wineq[:, :-1] # 除最后一列外所有列\n",
    "data, data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([6., 6., 6.,  ..., 6., 7., 6.]), torch.Size([4898]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target = wineq[:, -1] # 最后一列\n",
    "target, target.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6, 6, 6,  ..., 6, 7, 6])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target = wineq[:, -1].long()\n",
    "target"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 独热编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'target' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mNameError\u001B[0m                                 Traceback (most recent call last)",
      "Input \u001B[1;32mIn [4]\u001B[0m, in \u001B[0;36m<cell line: 2>\u001B[1;34m()\u001B[0m\n\u001B[0;32m      1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m  \u001B[38;5;21;01mtorch\u001B[39;00m\n\u001B[1;32m----> 2\u001B[0m target_onehot \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mzeros(\u001B[43mtarget\u001B[49m\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m], \u001B[38;5;241m10\u001B[39m)\n\u001B[0;32m      3\u001B[0m target_onehot\u001B[38;5;241m.\u001B[39mscatter_(\u001B[38;5;241m1\u001B[39m, target\u001B[38;5;241m.\u001B[39munsqueeze(\u001B[38;5;241m1\u001B[39m), \u001B[38;5;241m1.0\u001B[39m)\n\u001B[0;32m      5\u001B[0m target\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m]\n",
      "\u001B[1;31mNameError\u001B[0m: name 'target' is not defined"
     ]
    }
   ],
   "source": [
    "target_onehot = torch.zeros(target.shape[0], 10)\n",
    "target_onehot.scatter_(1, target.unsqueeze(1), 1.0)\n",
    "target.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[6],\n",
       "        [6],\n",
       "        [6],\n",
       "        ...,\n",
       "        [6],\n",
       "        [7],\n",
       "        [6]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_unsqueezed = target.unsqueeze(1)\n",
    "target_unsqueezed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据操作"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6.8548e+00, 2.7824e-01, 3.3419e-01, 6.3914e+00, 4.5772e-02, 3.5308e+01,\n",
       "        1.3836e+02, 9.9403e-01, 3.1883e+00, 4.8985e-01, 1.0514e+01])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_mean = torch.mean(data, dim=0)\n",
    "data_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([7.1211e-01, 1.0160e-02, 1.4646e-02, 2.5726e+01, 4.7733e-04, 2.8924e+02,\n",
       "        1.8061e+03, 8.9455e-06, 2.2801e-02, 1.3025e-02, 1.5144e+00])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_var = torch.var(data, dim=0)\n",
    "data_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.7209e-01, -8.1764e-02,  2.1325e-01,  ..., -1.2468e+00,\n",
       "         -3.4914e-01, -1.3930e+00],\n",
       "        [-6.5743e-01,  2.1587e-01,  4.7991e-02,  ...,  7.3992e-01,\n",
       "          1.3467e-03, -8.2418e-01],\n",
       "        [ 1.4756e+00,  1.7448e-02,  5.4378e-01,  ...,  4.7502e-01,\n",
       "         -4.3677e-01, -3.3662e-01],\n",
       "        ...,\n",
       "        [-4.2042e-01, -3.7940e-01, -1.1915e+00,  ..., -1.3131e+00,\n",
       "         -2.6152e-01, -9.0544e-01],\n",
       "        [-1.6054e+00,  1.1666e-01, -2.8253e-01,  ...,  1.0048e+00,\n",
       "         -9.6250e-01,  1.8574e+00],\n",
       "        [-1.0129e+00, -6.7703e-01,  3.7852e-01,  ...,  4.7502e-01,\n",
       "         -1.4882e+00,  1.0448e+00]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_normalized = (data - data_mean) / torch.sqrt(data_var)\n",
    "data_normalized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.bool, tensor(20))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bad_indexes = torch.le(target, 3)\n",
    "bad_indexes.shape, bad_indexes.dtype, bad_indexes.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 11])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bad_data = data[bad_indexes]\n",
    "bad_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 0 fixed acidity          7.60   6.89   6.73\n",
      " 1 volatile acidity       0.33   0.28   0.27\n",
      " 2 citric acid            0.34   0.34   0.33\n",
      " 3 residual sugar         6.39   6.71   5.26\n",
      " 4 chlorides              0.05   0.05   0.04\n",
      " 5 free sulfur dioxide   53.33  35.42  34.55\n",
      " 6 total sulfur dioxide 170.60 141.83 125.25\n",
      " 7 density                0.99   0.99   0.99\n",
      " 8 pH                     3.19   3.18   3.22\n",
      " 9 sulphates              0.47   0.49   0.50\n",
      "10 alcohol               10.34  10.26  11.42\n"
     ]
    }
   ],
   "source": [
    "bad_data = data[torch.le(target, 3)]\n",
    "# 对于numpy数组和PyTorch张量，＆运算符执行逻辑和运算\n",
    "mid_data = data[torch.gt(target, 3) & torch.lt(target, 7)]\n",
    "good_data = data[torch.ge(target, 7)]\n",
    "\n",
    "bad_mean = torch.mean(bad_data, dim=0)\n",
    "mid_mean = torch.mean(mid_data, dim=0)\n",
    "good_mean = torch.mean(good_data, dim=0)\n",
    "\n",
    "for i, args in enumerate(zip(col_list, bad_mean, mid_mean, good_mean)):\n",
    "    print('{:2} {:20} {:6.2f} {:6.2f} {:6.2f}'.format(i, *args))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 根据阈值进行预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.bool, tensor(2727))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "total_sulfur_threshold = 141.83\n",
    "total_sulfur_data = data[:,6]\n",
    "predicted_indexes = torch.lt(total_sulfur_data, total_sulfur_threshold)\n",
    "predicted_indexes.shape, predicted_indexes.dtype, predicted_indexes.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.bool, tensor(3258))"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "actual_indexes = torch.gt(target, 5)\n",
    "actual_indexes.shape, actual_indexes.dtype, actual_indexes.sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 评估预测好坏"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2018, 0.74000733406674, 0.6193984039287906)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_matches = torch.sum(actual_indexes & predicted_indexes).item()\n",
    "n_predicted = torch.sum(predicted_indexes).item()\n",
    "n_actual = torch.sum(actual_indexes).item()\n",
    "n_matches, n_matches / n_predicted, n_matches / n_actual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
